// Copyright (C) INFINI Labs & INFINI LIMITED.
//
// The INFINI Framework is offered under the GNU Affero General Public License v3.0
// and as commercial software.
//
// For commercial licensing, contact us at:
//   - Website: infinilabs.com
//   - Email: hello@infini.ltd
//
// Open Source licensed under AGPL V3:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

// Copyright 2018 Elasticsearch BV
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

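// Example invocation, run from inside the target package's directory
// (here "model"), and the summary it prints:
//
//	$ ../../../../cmd/generate-fastjson/generate-fastjson -o generated.go .
//	generated 3 methods in "generated.go"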

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"go/types"
	"io"
	"log"
	"os"
	"reflect"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
)

const (
	// fastjsonPath is the import path written into the generated file's
	// import block; generated code calls into this package.
	fastjsonPath = "infini.sh/framework/lib/fastjson_marshal"

	// marshalMethod is the method generated for each exported type that
	// does not already declare it.
	marshalMethod = "MarshalFastJSON"

	// isZeroMethod, when declared on a named type, is called to decide
	// whether an omitempty field of that type is skipped.
	isZeroMethod = "isZero"
)

var (
	// force is bound to -f: remove an existing output file instead of
	// aborting.
	force bool

	// outfile is bound to -o: the destination path, where "-" means
	// standard output.
	outfile string
)

func init() {
	flag.BoolVar(&force, "f", false, "remove the output file if it exists")
	flag.StringVar(&outfile, "o", "-", "file to which output will be written")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s <package>\n", os.Args[0])
		flag.PrintDefaults()
	}
}

// main loads the single Go package named on the command line. For every
// exported, non-alias named type declared in that package that does not
// already have a MarshalFastJSON method, it emits one.
//
// Output goes to stdout when -o is "-" (the default). Otherwise it goes to
// the named file, which must not already exist unless -f is given.
func main() {
	flag.Parse()
	if flag.NArg() != 1 {
		flag.Usage()
		os.Exit(1)
	}
	toStdout := outfile == "-"
	if !toStdout {
		if _, err := os.Stat(outfile); err == nil {
			if !force {
				fmt.Fprintf(os.Stderr, "%s already exists, and -f not specified; aborting\n", outfile)
				os.Exit(2)
			}
			if err := os.Remove(outfile); err != nil {
				log.Fatal(err)
			}
		}
	}

	cfg := &packages.Config{
		Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo,
	}
	pkgs, err := packages.Load(cfg, flag.Arg(0))
	if err != nil {
		fmt.Fprintf(os.Stderr, "load: %v\n", err)
		os.Exit(1)
	}
	if packages.PrintErrors(pkgs) > 0 {
		os.Exit(1)
	}
	// Exactly one package is supported. An empty result would make pkgs[0]
	// panic, and extra packages would otherwise be silently ignored.
	if len(pkgs) != 1 {
		fmt.Fprintf(os.Stderr, "load: pattern %q matched %d packages, want exactly 1\n", flag.Arg(0), len(pkgs))
		os.Exit(1)
	}
	pkg := pkgs[0]

	var buf bytes.Buffer
	fmt.Fprintf(&buf, `
// Code generated by "generate-fastjson". DO NOT EDIT.

package %s

import (
	%q
)
`[1:], pkg.Types.Name(), fastjsonPath)

	var generated int
	for _, f := range pkg.Syntax {
		for _, decl := range f.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok || genDecl.Tok != token.TYPE {
				continue
			}
			for _, spec := range genDecl.Specs {
				typeSpec, ok := spec.(*ast.TypeSpec)
				if !ok {
					continue
				}
				obj := pkg.TypesInfo.Defs[typeSpec.Name]
				if obj == nil || !obj.Exported() {
					continue
				}
				// Skip aliases (`type A = B`). They do not define a new type,
				// so a generated method would either duplicate B's or target
				// a foreign type. Their Type() also need not be *types.Named.
				typeName, ok := obj.(*types.TypeName)
				if !ok || typeName.IsAlias() {
					continue
				}
				named, ok := typeName.Type().(*types.Named)
				if !ok {
					continue
				}
				if !hasMethod(named, marshalMethod) {
					generate(&buf, named)
					generated++
				}
			}
		}
	}

	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		// Dump the unformatted source to stderr, not stdout, because stdout
		// may be the output destination.
		fmt.Fprintln(os.Stderr, buf.String())
		log.Fatal(err)
	}

	var out io.Writer = os.Stdout
	var file *os.File
	if !toStdout {
		f, err := os.Create(outfile)
		if err != nil {
			log.Fatal(err)
		}
		file, out = f, f
	}
	if _, err := out.Write(formatted); err != nil {
		log.Fatal(err)
	}
	if file != nil {
		// Close explicitly so a failed close is reported rather than ignored.
		if err := file.Close(); err != nil {
			log.Fatal(err)
		}
		fmt.Fprintf(os.Stderr, "generated %d methods in %q\n", generated, outfile)
	}
}

func generate(w *bytes.Buffer, named *types.Named) {
	structType, ok := named.Underlying().(*types.Struct)
	if !ok {
		panic(fmt.Errorf("unhandled type %T", named.Underlying()))
	}

	origw := w
	w = new(bytes.Buffer)
	defer func() {
		fmt.Fprintf(origw, "\nfunc (v *%s) %s(w *fastjson_marshal.Writer) error {\n", named.Obj().Name(), marshalMethod)

		// Hypothetically you could create a type whose names contains
		// "firstErr" which would force this. No big deal if the var is
		// never written to, this is just for aesthetics.
		mayError := strings.Contains(w.String(), "firstErr")
		if mayError {
			fmt.Fprintln(origw, "var firstErr error")
		}
		fmt.Fprintln(origw, `w.RawByte('{')`)
		w.WriteTo(origw)
		fmt.Fprintln(origw, `w.RawByte('}')`)
		if mayError {
			fmt.Fprintln(origw, "return firstErr")
		} else {
			fmt.Fprintln(origw, "return nil")
		}
		fmt.Fprintln(origw, "}")
	}()

	numFields := structType.NumFields()
	structFields := make([]structField, 0, numFields)
	for i := 0; i < numFields; i++ {
		structField, ok := makeStructField(structType, i)
		if !ok {
			continue
		}
		structFields = append(structFields, structField)
	}
	sort.Slice(structFields, func(i, j int) bool {
		// Put non-omitempty fields first, so we can elide
		// the runtime "first" tracking.
		switch {
		case !structFields[i].omitempty && structFields[j].omitempty:
			return true
		case structFields[i].omitempty && !structFields[j].omitempty:
			return false
		}
		return structFields[i].jsonName < structFields[j].jsonName
	})

	checkFirst := len(structFields) > 1 && structFields[0].omitempty
	if checkFirst {
		fmt.Fprintln(w, "first := true")
	}
	for i, f := range structFields {
		if f.omitempty {
			fmt.Fprintf(w, "if %s {", isNonZero("v."+f.fieldName, f.fieldType))
		}
		prefix := fmt.Sprintf(",%q:", f.jsonName)
		if checkFirst {
			fmt.Fprintf(w, `
const prefix = %q
if first {
	first = false
	w.RawString(prefix[1:])
} else {
	w.RawString(prefix)
}
`[1:], prefix)
		} else {
			if i == 0 {
				prefix = prefix[1:]
			}
			fmt.Fprintf(w, "w.RawString(%q)\n", prefix)
		}
		var nillable bool
		if !f.omitempty {
			// For nillable types (pointer, slice, map, interface),
			// emit a null check to write "null".
			switch f.fieldType.Underlying().(type) {
			case *types.Pointer:
				nillable = true
			case *types.Slice:
				nillable = true
			case *types.Map:
				nillable = true
			case *types.Interface:
				nillable = true
			}
			if nillable {
				fmt.Fprintf(w, `
if v.%s == nil {
	w.RawString("null")
} else {
`[1:], f.fieldName)
			}
		}
		generateValue(w, "v."+f.fieldName, f.fieldType)
		if f.omitempty || nillable {
			fmt.Fprintln(w, "}")
		}
	}
}

// generateValue appends to w the statements that encode expr, a Go
// expression of static type exprType, into the fastjson_marshal.Writer
// named w.
//
// Named types that declare MarshalFastJSON are delegated to that method, and
// the first error is recorded in firstErr. Otherwise, the encoding is chosen
// from the underlying type. Switching on Underlying() rather than on
// exprType itself also resolves type aliases. Newer go/types versions
// represent aliases, including the predeclared `any`, as *types.Alias,
// which would otherwise hit the panic below.
func generateValue(w *bytes.Buffer, expr string, exprType types.Type) {
	if named, ok := exprType.(*types.Named); ok {
		if hasMethod(named, marshalMethod) {
			fmt.Fprintf(w, `
if err := %s.%s(w); err != nil && firstErr == nil {
	firstErr = err
}
`[1:], expr, marshalMethod)
			return
		}
		// A named basic type (e.g. `type Status string`) is not assignable to
		// the predeclared parameter type of the Writer methods. Convert it
		// first so the generated code compiles.
		if basic, ok := named.Underlying().(*types.Basic); ok {
			expr = fmt.Sprintf("%s(%s)", basic.Name(), expr)
		}
	}

	switch t := exprType.Underlying().(type) {
	case *types.Pointer:
		generatePointerValue(w, expr, t)
	case *types.Slice:
		generateSliceValue(w, expr, t)
	case *types.Basic:
		generateBasicValue(w, expr, t)
	case *types.Map:
		generateMapValue(w, expr, t)
	case *types.Interface:
		generateInterfaceValue(w, expr, t)
	case *types.Struct:
		generateStructValue(w, expr, t)
	default:
		panic(fmt.Errorf("unhandled type %T", t))
	}
}

func generatePointerValue(w *bytes.Buffer, expr string, exprType *types.Pointer) {
	elem := exprType.Elem()
	switch t := elem.Underlying().(type) {
	case *types.Basic:
		generateBasicValue(w, "*"+expr, t)
	case *types.Struct:
		generateStructValue(w, expr, t)
	default:
		panic(fmt.Errorf("unhandled type %T", exprType))
	}
}

// generateBasicValue appends a single Writer call that encodes expr as a
// JSON scalar. Narrow integer kinds are widened to int64/uint64 because the
// Writer is only called with 64-bit integer methods here. Unsupported kinds
// (complex, uintptr, unsafe.Pointer, ...) panic.
func generateBasicValue(w *bytes.Buffer, expr string, exprType *types.Basic) {
	var method, widenTo string
	switch kind := exprType.Kind(); kind {
	case types.Bool:
		method = "Bool"
	case types.Int, types.Int8, types.Int16, types.Int32:
		method, widenTo = "Int64", "int64"
	case types.Int64:
		method = "Int64"
	case types.Uint, types.Uint8, types.Uint16, types.Uint32:
		method, widenTo = "Uint64", "uint64"
	case types.Uint64:
		method = "Uint64"
	case types.Float32:
		method = "Float32"
	case types.Float64:
		method = "Float64"
	case types.String:
		method = "String"
	default:
		panic(fmt.Errorf("unhandled basic kind %q", types.Typ[kind]))
	}

	arg := expr
	if widenTo != "" {
		arg = widenTo + "(" + expr + ")"
	}
	fmt.Fprintf(w, "w.%s(%s)\n", method, arg)
}

func generateStructValue(w *bytes.Buffer, expr string, exprType *types.Struct) {
	fmt.Fprintf(w, `
if err := %s.%s(w); err != nil && firstErr == nil {
	firstErr = err
}
`[1:], expr, marshalMethod)
}

// generateInterfaceValue appends code that hands expr to
// fastjson_marshal.Marshal and records the first error in firstErr. The
// static interface type exprType is not needed to build the call.
func generateInterfaceValue(w *bytes.Buffer, expr string, exprType *types.Interface) {
	w.WriteString("if err := fastjson_marshal.Marshal(w, " + expr + "); err != nil && firstErr == nil {\n")
	w.WriteString("\tfirstErr = err\n")
	w.WriteString("}\n")
}

func generateSliceValue(w *bytes.Buffer, expr string, exprType *types.Slice) {
	fmt.Fprintf(w, `
w.RawByte('[')
for i, v := range %s {
	if i != 0 {
		w.RawByte(',')
	}
`[1:], expr)
	generateValue(w, "v", exprType.Elem())
	fmt.Fprintln(w, `
}
w.RawByte(']')`[1:])
}

func generateMapValue(w *bytes.Buffer, expr string, exprType *types.Map) {
	fmt.Fprintf(w, `
w.RawByte('{')
{
	first := true
	for k, v := range %s {
		if first {
			first = false
		} else {
			w.RawByte(',')
		}
`[1:], expr)
	generateValue(w, "k", exprType.Key())
	fmt.Fprintln(w, "w.RawByte(':')")
	generateValue(w, "v", exprType.Elem())
	fmt.Fprintln(w, `
}
}
w.RawByte('}')`[1:])
}

// isNonZero returns a Go boolean expression that is true when expr, of type
// t, should be written by a field tagged omitempty.
//
// A named type that declares isZero decides for itself. Otherwise:
//   - Pointers, slices, maps and interfaces are compared with nil.
//   - Basic types are compared with their zero literal.
//   - Struct values cannot be compared with nil, so they are never treated
//     as empty. This is the same rule encoding/json applies; declare isZero
//     on the type to make omitempty effective.
//
// The switch uses Underlying() so that type aliases (e.g. `any` under newer
// go/types) resolve to their real kind.
func isNonZero(expr string, t types.Type) string {
	if named, ok := t.(*types.Named); ok {
		if hasMethod(named, isZeroMethod) {
			return fmt.Sprintf("!%s.%s()", expr, isZeroMethod)
		}
	}
	switch u := t.Underlying().(type) {
	case *types.Pointer, *types.Slice, *types.Map, *types.Interface:
		return expr + " != nil"
	case *types.Struct:
		return "true"
	case *types.Basic:
		zero := "0"
		switch u.Kind() {
		case types.String:
			zero = `""`
		case types.Bool:
			zero = "false"
		}
		return fmt.Sprintf("%s != %s", expr, zero)
	default:
		// Report the offending expression in the panic itself instead of
		// printing it to stdout, which may be the generator's output.
		panic(fmt.Errorf("unhandled type %T for %s", u, expr))
	}
}

// structField describes one JSON-encodable field of a struct type.
type structField struct {
	fieldName string     // Go field name; generated code reads it as v.<fieldName>
	jsonName  string     // object key: the json tag name, or fieldName if none
	fieldType types.Type // static type of the field
	omitempty bool       // json tag carries the omitempty option
}

// makeStructField describes the i'th field of structType. It reports false
// for unexported fields and for fields tagged `json:"-"`. The tag `json:"-,"`
// yields the literal key "-". The only tag option understood is omitempty;
// any other option panics rather than being ignored.
func makeStructField(structType *types.Struct, i int) (structField, bool) {
	v := structType.Field(i)
	if !v.Exported() {
		return structField{}, false
	}
	sf := structField{
		fieldName: v.Name(),
		jsonName:  v.Name(),
		fieldType: v.Type(),
	}

	tag, tagged := reflect.StructTag(structType.Tag(i)).Lookup("json")
	if !tagged {
		return sf, true
	}
	if tag == "-" {
		return structField{}, false
	}

	// Split "name,options"; with no comma the whole tag is the name.
	name, opts := tag, ""
	if comma := strings.IndexByte(tag, ','); comma >= 0 {
		name, opts = tag[:comma], tag[comma+1:]
	}
	switch opts {
	case "":
	case "omitempty":
		sf.omitempty = true
	default:
		panic("unhandled json tag: " + tag)
	}
	if name != "" {
		sf.jsonName = name
	}
	return sf, true
}

// hasMethod reports whether named explicitly declares a method called
// method, with either a value or a pointer receiver. Methods promoted from
// embedded fields are not considered.
func hasMethod(named *types.Named, method string) bool {
	count := named.NumMethods()
	for idx := 0; idx < count; idx++ {
		if named.Method(idx).Name() == method {
			return true
		}
	}
	return false
}
